{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 第六节：优化技巧与解决方案升级"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6.2 深度学习解决方案：TextCNN建模"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2.2 数据读取"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import lightgbm as lgb\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "\n",
    "from tqdm import tqdm_notebook\n",
    "from sklearn.preprocessing import LabelBinarizer,LabelEncoder\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory holding the competition CSV files, relative to this notebook.\n",
    "path = '../security_data/'\n",
    "\n",
    "# Read both tables the same way; each file name is appended to `path`.\n",
    "train, test = (\n",
    "    pd.read_csv(path + file_name)\n",
    "    for file_name in ('security_train.csv', 'security_test.csv')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tqdm import tqdm  \n",
    "\n",
    "class _Data_Preprocess:\n",
    "    def __init__(self):\n",
    "        self.int8_max = np.iinfo(np.int8).max\n",
    "        self.int8_min = np.iinfo(np.int8).min\n",
    "\n",
    "        self.int16_max = np.iinfo(np.int16).max\n",
    "        self.int16_min = np.iinfo(np.int16).min\n",
    "\n",
    "        self.int32_max = np.iinfo(np.int32).max\n",
    "        self.int32_min = np.iinfo(np.int32).min\n",
    "\n",
    "        self.int64_max = np.iinfo(np.int64).max\n",
    "        self.int64_min = np.iinfo(np.int64).min\n",
    "\n",
    "        self.float16_max = np.finfo(np.float16).max\n",
    "        self.float16_min = np.finfo(np.float16).min\n",
    "\n",
    "        self.float32_max = np.finfo(np.float32).max\n",
    "        self.float32_min = np.finfo(np.float32).min\n",
    "\n",
    "        self.float64_max = np.finfo(np.float64).max\n",
    "        self.float64_min = np.finfo(np.float64).min\n",
    "\n",
    "    def _get_type(self, min_val, max_val, types):\n",
    "        if types == 'int':\n",
    "            if max_val <= self.int8_max and min_val >= self.int8_min:\n",
    "                return np.int8\n",
    "            elif max_val <= self.int16_max <= max_val and min_val >= self.int16_min:\n",
    "                return np.int16\n",
    "            elif max_val <= self.int32_max and min_val >= self.int32_min:\n",
    "                return np.int32\n",
    "            return None\n",
    "\n",
    "        elif types == 'float':\n",
    "            if max_val <= self.float16_max and min_val >= self.float16_min:\n",
    "                return np.float16\n",
    "            if max_val <= self.float32_max and min_val >= self.float32_min:\n",
    "                return np.float32\n",
    "            if max_val <= self.float64_max and min_val >= self.float64_min:\n",
    "                return np.float64\n",
    "            return None\n",
    "\n",
    "    def _memory_process(self, df):\n",
    "        init_memory = df.memory_usage().sum() / 1024 ** 2 / 1024\n",
    "        print('Original data occupies {} GB memory.'.format(init_memory))\n",
    "        df_cols = df.columns\n",
    "\n",
    "          \n",
    "        for col in tqdm_notebook(df_cols):\n",
    "            try:\n",
    "                if 'float' in str(df[col].dtypes):\n",
    "                    max_val = df[col].max()\n",
    "                    min_val = df[col].min()\n",
    "                    trans_types = self._get_type(min_val, max_val, 'float')\n",
    "                    if trans_types is not None:\n",
    "                        df[col] = df[col].astype(trans_types)\n",
    "                elif 'int' in str(df[col].dtypes):\n",
    "                    max_val = df[col].max()\n",
    "                    min_val = df[col].min()\n",
    "                    trans_types = self._get_type(min_val, max_val, 'int')\n",
    "                    if trans_types is not None:\n",
    "                        df[col] = df[col].astype(trans_types)\n",
    "            except:\n",
    "                print(' Can not do any process for column, {}.'.format(col)) \n",
    "        afterprocess_memory = df.memory_usage().sum() / 1024 ** 2 / 1024\n",
    "        print('After processing, the data occupies {} GB memory.'.format(afterprocess_memory))\n",
    "        return df\n",
    "\n",
    "memory_process = _Data_Preprocess()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>file_id</th>\n",
       "      <th>label</th>\n",
       "      <th>api</th>\n",
       "      <th>tid</th>\n",
       "      <th>index</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>LdrLoadDll</td>\n",
       "      <td>2488</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>LdrGetProcedureAddress</td>\n",
       "      <td>2488</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>LdrGetProcedureAddress</td>\n",
       "      <td>2488</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>LdrGetProcedureAddress</td>\n",
       "      <td>2488</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>5</td>\n",
       "      <td>LdrGetProcedureAddress</td>\n",
       "      <td>2488</td>\n",
       "      <td>4</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   file_id  label                     api   tid  index\n",
       "0        1      5              LdrLoadDll  2488      0\n",
       "1        1      5  LdrGetProcedureAddress  2488      1\n",
       "2        1      5  LdrGetProcedureAddress  2488      2\n",
       "3        1      5  LdrGetProcedureAddress  2488      3\n",
       "4        1      5  LdrGetProcedureAddress  2488      4"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Preview the first rows of the raw training table.\n",
    "train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2.3 数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the distinct API names seen in training (the strings that get converted to numbers).\n",
    "unique_api = train['api'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number the APIs 1..N in the order they appear in unique_api; 0 is never assigned\n",
    "# (presumably so it stays free as the padding value - confirm).\n",
    "api2index = dict(zip(unique_api, range(1, len(unique_api) + 1)))\n",
    "# Reverse lookup: integer id back to API name.\n",
    "index2api = dict(zip(range(1, len(unique_api) + 1), unique_api))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode every API call as its integer id.\n",
    "train['api_idx'] = train['api'].map(api2index)\n",
    "# api2index is built from the training set only, so an API that appears just in the\n",
    "# test set has no entry and .map() yields NaN for it. Send those calls to index 0,\n",
    "# which api2index never assigns, and keep the column integer-typed so no NaN reaches\n",
    "# the sequence padding or the embedding lookup.\n",
    "test['api_idx']  = test['api'].map(api2index).fillna(0).astype('int64')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 获取每个文件对应的字符串序列\n",
    "def get_sequence(df,period_idx):\n",
    "    seq_list = []\n",
    "    for _id,begin in enumerate(period_idx[:-1]):\n",
    "        seq_list.append(df.iloc[begin:period_idx[_id+1]]['api_idx'].values)\n",
    "    seq_list.append(df.iloc[period_idx[-1]:]['api_idx'].values)\n",
    "    return seq_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Index label of the first row of each file_id (drop_duplicates keeps the first\n",
    "# occurrence by default), i.e. where each file's run of rows begins.\n",
    "train_period_idx = train['file_id'].drop_duplicates().index.values\n",
    "test_period_idx  = test['file_id'].drop_duplicates().index.values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# File-level tables: one row per distinct (file_id, label) pair for train and\n",
    "# one row per distinct file_id for test, keeping the first occurrence.\n",
    "train_df = train.loc[:, ['file_id', 'label']].drop_duplicates()\n",
    "test_df  = test.loc[:, ['file_id']].drop_duplicates()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Store each file's API-id sequence next to its file-level row.\n",
    "# NOTE(review): this relies on get_sequence returning exactly one sequence per row\n",
    "# of train_df / test_df, in the same order - confirm for files with several labels.\n",
    "train_df['seq'] = get_sequence(train,train_period_idx)\n",
    "test_df['seq']  = get_sequence(test,test_period_idx)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2.4 TextCNN网络结构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "from keras.preprocessing.text import Tokenizer\n",
    "from keras.preprocessing.sequence import pad_sequences\n",
    "from keras.layers import Dense, Input, LSTM, Lambda, Embedding, Dropout, Activation,GRU,Bidirectional\n",
    "from keras.layers import Conv1D,Conv2D,MaxPooling2D,GlobalAveragePooling1D,GlobalMaxPooling1D, MaxPooling1D, Flatten\n",
    "from keras.layers import CuDNNGRU, CuDNNLSTM, SpatialDropout1D\n",
    "from keras.layers.merge import concatenate, Concatenate, Average, Dot, Maximum, Multiply, Subtract, average\n",
    "from keras.models import Model\n",
    "from keras.optimizers import RMSprop,Adam\n",
    "from keras.layers.normalization import BatchNormalization\n",
    "from keras.callbacks import EarlyStopping, ModelCheckpoint\n",
    "from keras.optimizers import SGD\n",
    "from keras import backend as K\n",
    "from sklearn.decomposition import TruncatedSVD, NMF, LatentDirichletAllocation\n",
    "from keras.layers import SpatialDropout1D\n",
    "from keras.layers.wrappers import Bidirectional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def TextCNN(max_len,max_cnt,embed_size, num_filters,kernel_size,conv_action, mask_zero, num_classes=8):\n",
    "    \"\"\"Build and compile a multi-branch 1-D convolutional classifier.\n",
    "\n",
    "    The integer input sequence is embedded, passed through one Conv1D +\n",
    "    GlobalMaxPooling1D branch per kernel width, and the pooled branches are\n",
    "    joined and fed to a 256-unit dense layer and a softmax output.\n",
    "\n",
    "    Args:\n",
    "        max_len: Length of every input sequence.\n",
    "        max_cnt: Input dimension of the Embedding layer (number of distinct ids).\n",
    "        embed_size: Embedding output dimension.\n",
    "        num_filters: Number of filters in each Conv1D branch.\n",
    "        kernel_size: Iterable of kernel widths, one Conv1D branch per width.\n",
    "        conv_action: Activation used by the Conv1D layers.\n",
    "        mask_zero: Forwarded to Embedding(mask_zero=...).\n",
    "        num_classes: Number of softmax outputs; defaults to 8, the value that\n",
    "            was previously hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        A Keras Model compiled with categorical cross-entropy, the 'adam'\n",
    "        optimizer and the accuracy metric.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `kernel_size` is empty.\n",
    "    \"\"\"\n",
    "    _input = Input(shape=(max_len,), dtype='int32')\n",
    "    _embed = Embedding(max_cnt, embed_size, input_length=max_len, mask_zero=mask_zero)(_input)\n",
    "    _embed = SpatialDropout1D(0.15)(_embed)\n",
    "\n",
    "    # One convolution branch per kernel width, each reduced by global max pooling.\n",
    "    pooled_branches = []\n",
    "    for _kernel_size in kernel_size:\n",
    "        conv1d = Conv1D(filters=num_filters, kernel_size=_kernel_size, activation=conv_action)(_embed)\n",
    "        pooled_branches.append(GlobalMaxPooling1D()(conv1d))\n",
    "\n",
    "    if not pooled_branches:\n",
    "        raise ValueError('kernel_size must contain at least one kernel width')\n",
    "\n",
    "    # concatenate() is only needed for two or more tensors; a single branch is\n",
    "    # used directly so a one-element kernel_size also works.\n",
    "    if len(pooled_branches) == 1:\n",
    "        fc = pooled_branches[0]\n",
    "    else:\n",
    "        fc = concatenate(pooled_branches)\n",
    "\n",
    "    fc = Dropout(0.5)(fc)\n",
    "    fc = Dense(256, activation='relu')(fc)\n",
    "    fc = Dropout(0.25)(fc)\n",
    "    preds = Dense(num_classes, activation = 'softmax')(fc)\n",
    "\n",
    "    model = Model(inputs=_input, outputs=preds)\n",
    "    model.compile(loss='categorical_crossentropy',\n",
    "        optimizer='adam',\n",
    "        metrics=['accuracy'])\n",
    "    return model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-hot label matrix built from the per-file labels.\n",
    "train_labels = pd.get_dummies(train_df.label).values\n",
    "# Bring every API-id sequence to a common length of 6000.\n",
    "# NOTE(review): which end is padded or truncated is left to the pad_sequences\n",
    "# defaults, and 6000 duplicates the max_len setting used for the model - confirm.\n",
    "train_seq    = pad_sequences(train_df.seq.values, maxlen = 6000)\n",
    "test_seq     = pad_sequences(test_df.seq.values, maxlen = 6000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2.5 TextCNN训练和预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import StratifiedKFold,KFold \n",
    "# Plain (non-stratified) 5-fold split despite the variable name. The seed is fixed\n",
    "# so every run produces the same folds: the training cell reloads checkpoints by\n",
    "# fold number, and with unseeded shuffling a reloaded model could be scored on rows\n",
    "# it was trained on.\n",
    "skf = KFold(n_splits=5, shuffle=True, random_state=2019)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings passed to TextCNN(...) and read by the training loop.\n",
    "max_len     = 6000                  # input sequence length\n",
    "max_cnt     = 295                   # Embedding input dimension; NOTE(review): must exceed the largest api_idx (len(api2index) + 1) - confirm 295 matches the data\n",
    "embed_size  = 256                   # embedding output dimension (also part of the checkpoint file name)\n",
    "num_filters = 64                    # filters per Conv1D branch\n",
    "kernel_size = [2,4,6,8,10,12,14]    # one Conv1D branch per kernel width\n",
    "conv_action = 'relu'                # Conv1D activation\n",
    "mask_zero   = False                 # forwarded to Embedding(mask_zero=...)\n",
    "TRAIN       = True                  # when False the loop skips fit() and only reloads saved fold weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FOLD: \n",
      "2778 11109\n",
      "Train on 11109 samples, validate on 2778 samples\n",
      "Epoch 1/100\n",
      "11109/11109 [==============================] - 75s 7ms/step - loss: 0.8165 - acc: 0.7370 - val_loss: 0.4825 - val_acc: 0.8485\n",
      "Epoch 2/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.4772 - acc: 0.8499 - val_loss: 0.4141 - val_acc: 0.8625\n",
      "Epoch 3/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.4172 - acc: 0.8673 - val_loss: 0.3785 - val_acc: 0.8780\n",
      "Epoch 4/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.3768 - acc: 0.8769 - val_loss: 0.3821 - val_acc: 0.8726\n",
      "Epoch 5/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.3568 - acc: 0.8831 - val_loss: 0.3932 - val_acc: 0.8783\n",
      "Epoch 6/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.3388 - acc: 0.8893 - val_loss: 0.3566 - val_acc: 0.8902\n",
      "Epoch 7/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.3179 - acc: 0.8968 - val_loss: 0.3553 - val_acc: 0.8902\n",
      "Epoch 8/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.3050 - acc: 0.8991 - val_loss: 0.3590 - val_acc: 0.8870\n",
      "Epoch 9/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.2913 - acc: 0.9006 - val_loss: 0.3593 - val_acc: 0.8909\n",
      "Epoch 10/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.2812 - acc: 0.9047 - val_loss: 0.3528 - val_acc: 0.8906\n",
      "Epoch 11/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.2633 - acc: 0.9054 - val_loss: 0.3608 - val_acc: 0.8823\n",
      "Epoch 12/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.2665 - acc: 0.9103 - val_loss: 0.3589 - val_acc: 0.8859\n",
      "Epoch 13/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.2495 - acc: 0.9097 - val_loss: 0.3570 - val_acc: 0.8909\n",
      "2778/2778 [==============================] - 4s 1ms/step\n",
      "12955/12955 [==============================] - 13s 980us/step\n",
      "FOLD: \n",
      "2778 11109\n",
      "Train on 11109 samples, validate on 2778 samples\n",
      "Epoch 1/100\n",
      "11109/11109 [==============================] - 65s 6ms/step - loss: 0.8297 - acc: 0.7290 - val_loss: 0.4925 - val_acc: 0.8463\n",
      "Epoch 2/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.4808 - acc: 0.8442 - val_loss: 0.4115 - val_acc: 0.8690\n",
      "Epoch 3/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.4149 - acc: 0.8643 - val_loss: 0.4037 - val_acc: 0.8715\n",
      "Epoch 4/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.3799 - acc: 0.8774 - val_loss: 0.3798 - val_acc: 0.8841\n",
      "Epoch 5/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.3530 - acc: 0.8850 - val_loss: 0.3773 - val_acc: 0.8870\n",
      "Epoch 6/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.3291 - acc: 0.8924 - val_loss: 0.3676 - val_acc: 0.8855\n",
      "Epoch 7/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.3115 - acc: 0.8959 - val_loss: 0.3773 - val_acc: 0.8888\n",
      "Epoch 8/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.3032 - acc: 0.8998 - val_loss: 0.3518 - val_acc: 0.8891\n",
      "Epoch 9/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.2892 - acc: 0.9027 - val_loss: 0.3655 - val_acc: 0.8920\n",
      "Epoch 10/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.2792 - acc: 0.9056 - val_loss: 0.3615 - val_acc: 0.8906\n",
      "Epoch 11/100\n",
      "11109/11109 [==============================] - 64s 6ms/step - loss: 0.2736 - acc: 0.9085 - val_loss: 0.3719 - val_acc: 0.8924\n",
      "2778/2778 [==============================] - 3s 1ms/step\n",
      "12955/12955 [==============================] - 12s 951us/step\n",
      "FOLD: \n",
      "2777 11110\n",
      "Train on 11110 samples, validate on 2777 samples\n",
      "Epoch 1/100\n",
      "11110/11110 [==============================] - 67s 6ms/step - loss: 0.8388 - acc: 0.7239 - val_loss: 0.4323 - val_acc: 0.8657\n",
      "Epoch 2/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.4969 - acc: 0.8418 - val_loss: 0.3881 - val_acc: 0.8783\n",
      "Epoch 3/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.4276 - acc: 0.8631 - val_loss: 0.3587 - val_acc: 0.8855\n",
      "Epoch 4/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3882 - acc: 0.8770 - val_loss: 0.3542 - val_acc: 0.8898\n",
      "Epoch 5/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3590 - acc: 0.8875 - val_loss: 0.3640 - val_acc: 0.8920\n",
      "Epoch 6/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3437 - acc: 0.8851 - val_loss: 0.3445 - val_acc: 0.8992\n",
      "Epoch 7/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3235 - acc: 0.8912 - val_loss: 0.3552 - val_acc: 0.8923\n",
      "Epoch 8/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3079 - acc: 0.8986 - val_loss: 0.3491 - val_acc: 0.8927\n",
      "Epoch 9/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2937 - acc: 0.9035 - val_loss: 0.3370 - val_acc: 0.9003\n",
      "Epoch 10/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2897 - acc: 0.9018 - val_loss: 0.3523 - val_acc: 0.8959\n",
      "Epoch 11/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2788 - acc: 0.9066 - val_loss: 0.3519 - val_acc: 0.8967\n",
      "Epoch 12/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2653 - acc: 0.9131 - val_loss: 0.3571 - val_acc: 0.8952\n",
      "2777/2777 [==============================] - 3s 1ms/step\n",
      "12955/12955 [==============================] - 12s 955us/step\n",
      "FOLD: \n",
      "2777 11110\n",
      "Train on 11110 samples, validate on 2777 samples\n",
      "Epoch 1/100\n",
      "11110/11110 [==============================] - 66s 6ms/step - loss: 0.8286 - acc: 0.7326 - val_loss: 0.4647 - val_acc: 0.8596\n",
      "Epoch 2/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.4807 - acc: 0.8524 - val_loss: 0.4081 - val_acc: 0.8704\n",
      "Epoch 3/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.4233 - acc: 0.8656 - val_loss: 0.3920 - val_acc: 0.8808\n",
      "Epoch 4/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3771 - acc: 0.8819 - val_loss: 0.3767 - val_acc: 0.8804\n",
      "Epoch 5/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3529 - acc: 0.8861 - val_loss: 0.3990 - val_acc: 0.8725\n",
      "Epoch 6/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3353 - acc: 0.8900 - val_loss: 0.3776 - val_acc: 0.8822\n",
      "Epoch 7/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3234 - acc: 0.8935 - val_loss: 0.3717 - val_acc: 0.8855\n",
      "Epoch 8/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3012 - acc: 0.9010 - val_loss: 0.3758 - val_acc: 0.8848\n",
      "Epoch 9/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2907 - acc: 0.9011 - val_loss: 0.3656 - val_acc: 0.8862\n",
      "Epoch 10/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2852 - acc: 0.9022 - val_loss: 0.3676 - val_acc: 0.8858\n",
      "Epoch 11/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2683 - acc: 0.9085 - val_loss: 0.3630 - val_acc: 0.8862\n",
      "Epoch 12/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2595 - acc: 0.9091 - val_loss: 0.3768 - val_acc: 0.8884\n",
      "Epoch 13/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2533 - acc: 0.9140 - val_loss: 0.3817 - val_acc: 0.8822\n",
      "Epoch 14/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2464 - acc: 0.9155 - val_loss: 0.3757 - val_acc: 0.8862\n",
      "2777/2777 [==============================] - 3s 1ms/step\n",
      "12955/12955 [==============================] - 12s 949us/step\n",
      "FOLD: \n",
      "2777 11110\n",
      "Train on 11110 samples, validate on 2777 samples\n",
      "Epoch 1/100\n",
      "11110/11110 [==============================] - 65s 6ms/step - loss: 0.8168 - acc: 0.7315 - val_loss: 0.4718 - val_acc: 0.8567\n",
      "Epoch 2/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.4880 - acc: 0.8459 - val_loss: 0.4047 - val_acc: 0.8711\n",
      "Epoch 3/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.4224 - acc: 0.8674 - val_loss: 0.3871 - val_acc: 0.8732\n",
      "Epoch 4/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3900 - acc: 0.8728 - val_loss: 0.3676 - val_acc: 0.8812\n",
      "Epoch 5/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3581 - acc: 0.8846 - val_loss: 0.3713 - val_acc: 0.8819\n",
      "Epoch 6/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3391 - acc: 0.8890 - val_loss: 0.3542 - val_acc: 0.8905\n",
      "Epoch 7/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3158 - acc: 0.8975 - val_loss: 0.3610 - val_acc: 0.8902\n",
      "Epoch 8/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.3074 - acc: 0.8986 - val_loss: 0.3520 - val_acc: 0.8887\n",
      "Epoch 9/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2905 - acc: 0.9026 - val_loss: 0.3588 - val_acc: 0.8941\n",
      "Epoch 10/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2795 - acc: 0.9032 - val_loss: 0.3417 - val_acc: 0.8923\n",
      "Epoch 11/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2747 - acc: 0.9044 - val_loss: 0.3456 - val_acc: 0.8912\n",
      "Epoch 12/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2546 - acc: 0.9131 - val_loss: 0.3517 - val_acc: 0.8902\n",
      "Epoch 13/100\n",
      "11110/11110 [==============================] - 64s 6ms/step - loss: 0.2483 - acc: 0.9144 - val_loss: 0.3785 - val_acc: 0.8909\n",
      "2777/2777 [==============================] - 3s 1ms/step\n",
      "12955/12955 [==============================] - 12s 949us/step\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# NOTE(review): set after keras was imported by an earlier cell - confirm the\n",
    "# device selection still takes effect in this environment.\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0,1\"\n",
    "\n",
    "# Checkpoints are written under ./NN; create it so the first save cannot fail\n",
    "# on a missing directory.\n",
    "os.makedirs('./NN', exist_ok=True)\n",
    "\n",
    "# Out-of-fold predictions for train and fold-averaged predictions for test,\n",
    "# one column per label column.\n",
    "meta_train = np.zeros(shape = (len(train_seq), train_labels.shape[1]))\n",
    "meta_test = np.zeros(shape = (len(test_seq), train_labels.shape[1]))\n",
    "FLAG = True\n",
    "\n",
    "for i, (tr_ind, te_ind) in enumerate(skf.split(train_labels), start=1):\n",
    "    # Fixed: the format string had no placeholder, so the fold number was never printed.\n",
    "    print('FOLD: {}'.format(i))\n",
    "    print(len(te_ind),len(tr_ind)) \n",
    "    model_name = 'benchmark_textcnn_fold_'+str(i)\n",
    "    X_train,X_train_label = train_seq[tr_ind],train_labels[tr_ind]\n",
    "    X_val,X_val_label     = train_seq[te_ind],train_labels[te_ind]\n",
    "    \n",
    "    model = TextCNN(max_len,max_cnt,embed_size,num_filters,kernel_size,conv_action,mask_zero)\n",
    "    \n",
    "    model_save_path = './NN/%s_%s.hdf5'%(model_name,embed_size)\n",
    "    # Stop after 3 epochs without val_loss improvement; only the best weights are kept on disk.\n",
    "    early_stopping = EarlyStopping(monitor='val_loss', patience=3)\n",
    "    model_checkpoint = ModelCheckpoint(model_save_path, save_best_only=True, save_weights_only=True)\n",
    "    if TRAIN and FLAG:\n",
    "        model.fit(X_train,X_train_label,validation_data=(X_val,X_val_label),epochs=100,batch_size=64,shuffle=True,callbacks=[early_stopping,model_checkpoint] )\n",
    "    \n",
    "    # Predict with the checkpointed weights rather than the last-epoch weights.\n",
    "    model.load_weights(model_save_path)\n",
    "    pred_val = model.predict(X_val,batch_size=128,verbose=1)\n",
    "    pred_test = model.predict(test_seq,batch_size=128,verbose=1)\n",
    "    \n",
    "    meta_train[te_ind] = pred_val\n",
    "    meta_test += pred_test\n",
    "    # Release the graph before the next fold's model is built.\n",
    "    K.clear_session()\n",
    "\n",
    "# Average over the actual number of folds instead of a hard-coded 5.\n",
    "meta_test /= skf.get_n_splits()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 6.2.6 结果提交"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Probability columns of the submission, one per class.\n",
    "prob_cols = ['prob0','prob1','prob2','prob3','prob4','prob5','prob6','prob7']\n",
    "\n",
    "# Create the columns first, then fill them with the fold-averaged test predictions.\n",
    "for col in prob_cols:\n",
    "    test_df[col] = 0\n",
    "test_df[prob_cols] = meta_test\n",
    "\n",
    "# Write file_id plus the class probabilities, without the DataFrame index.\n",
    "test_df[['file_id'] + prob_cols].to_csv('nn_baseline_5fold.csv',index = None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  },
  "toc": {
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": true,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
